from __future__ import annotations

import datetime as dt
import functools
import inspect
import os
import sys
import time as time_module
import uuid
from collections.abc import Awaitable, Callable, Generator
from collections.abc import Generator as TypingGenerator
from time import gmtime as orig_gmtime
from time import struct_time
from types import TracebackType
from typing import Any, TypeAlias, TypeVar, cast, overload
from unittest import TestCase
from zoneinfo import ZoneInfo

import _time_machine

# time.clock_gettime and time.CLOCK_REALTIME not always available
# e.g. on builds against old macOS = official Python.org installer
try:
    from time import CLOCK_REALTIME
except ImportError:
    # Dummy value that won't compare equal to any value
    CLOCK_REALTIME = sys.maxsize  # type: ignore[misc]

try:
    from time import tzset

    HAVE_TZSET = True
except ImportError:  # pragma: no cover
    # Windows
    HAVE_TZSET = False

try:
    from dateutil.parser import parse as parse_datetime

    HAVE_DATEUTIL = True
except ImportError:  # pragma: no cover
    HAVE_DATEUTIL = False

try:
    import pytest
except ImportError:  # pragma: no cover
    HAVE_PYTEST = False
else:
    HAVE_PYTEST = True

NANOSECONDS_PER_SECOND = 1_000_000_000

# Windows' time epoch is not unix epoch but in 1601. This constant helps us
# translate to it.
_system_epoch = orig_gmtime(0)
SYSTEM_EPOCH_TIMESTAMP_NS = int(
    dt.datetime(
        _system_epoch.tm_year,
        _system_epoch.tm_mon,
        _system_epoch.tm_mday,
        _system_epoch.tm_hour,
        _system_epoch.tm_min,
        _system_epoch.tm_sec,
        tzinfo=dt.timezone.utc,
    ).timestamp()
    * NANOSECONDS_PER_SECOND
)

DestinationBaseType: TypeAlias = (
    int | float | dt.datetime | dt.timedelta | dt.date | str
)
DestinationType: TypeAlias = (
    DestinationBaseType
    | Callable[[], DestinationBaseType]
    | TypingGenerator[DestinationBaseType, None, None]
)

_F = TypeVar("_F", bound=Callable[..., Any])
_AF = TypeVar("_AF", bound=Callable[..., Awaitable[Any]])
TestCaseType = TypeVar("TestCaseType", bound=type[TestCase])

# copied from typeshed:
_TimeTuple = tuple[int, int, int, int, int, int, int, int, int]


def extract_timestamp_tzname(
    destination: DestinationType,
) -> tuple[float, str | None]:
    dest: DestinationBaseType
    if isinstance(destination, Generator):
        dest = next(destination)
    elif callable(destination):
        dest = destination()
    else:
        dest = destination

    timestamp: float
    tzname: str | None = None
    if isinstance(dest, int):
        timestamp = float(dest)
    elif isinstance(dest, float):
        timestamp = dest
    elif isinstance(dest, dt.datetime):
        if isinstance(dest.tzinfo, ZoneInfo):
            tzname = dest.tzinfo.key
        elif dest.tzinfo == dt.timezone.utc:
            tzname = "UTC"
        elif dest.tzinfo is None:
            dest = dest.replace(tzinfo=dt.timezone.utc)
        timestamp = dest.timestamp()
    elif isinstance(dest, dt.timedelta):
        timestamp = time_module.time() + dest.total_seconds()
    elif isinstance(dest, dt.date):
        timestamp = dt.datetime.combine(
            dest, dt.time(0, 0), tzinfo=dt.timezone.utc
        ).timestamp()
    elif isinstance(dest, str):
        try:
            parsed = dt.datetime.fromisoformat(dest)
        except ValueError as exc:
            if HAVE_DATEUTIL:
                try:
                    parsed = parse_datetime(dest)
                except ValueError as dateutil_exc:
                    raise dateutil_exc from None
            else:
                raise exc

        timestamp = parsed.timestamp()
    else:
        raise TypeError(f"Unsupported destination {dest!r}")

    return timestamp, tzname


class Traveller:
    def __init__(
        self,
        destination_timestamp: float,
        destination_tzname: str | None,
        tick: bool,
    ) -> None:
        self._destination_timestamp_ns = int(
            destination_timestamp * NANOSECONDS_PER_SECOND
        )
        self._destination_tzname = destination_tzname
        self._tick = tick
        self._requested = False

    def time(self) -> float:
        return self.time_ns() / NANOSECONDS_PER_SECOND

    def time_ns(self) -> int:
        if not self._tick:
            return self._destination_timestamp_ns

        base = SYSTEM_EPOCH_TIMESTAMP_NS + self._destination_timestamp_ns
        now_ns: int = _time_machine.original_time_ns()

        if not self._requested:
            self._requested = True
            self._real_start_timestamp_ns = now_ns
            return base

        return base + (now_ns - self._real_start_timestamp_ns)

    def shift(self, delta: dt.timedelta | int | float) -> None:
        if isinstance(delta, dt.timedelta):
            total_seconds = delta.total_seconds()
        elif isinstance(delta, (int, float)):
            total_seconds = delta
        else:
            raise TypeError(f"Unsupported type for delta argument: {delta!r}")

        self._destination_timestamp_ns += int(total_seconds * NANOSECONDS_PER_SECOND)

    def move_to(
        self,
        destination: DestinationType,
        tick: bool | None = None,
    ) -> None:
        self._stop()
        timestamp, self._destination_tzname = extract_timestamp_tzname(destination)
        self._destination_timestamp_ns = int(timestamp * NANOSECONDS_PER_SECOND)
        self._requested = False
        self._start()
        if tick is not None:
            self._tick = tick

    def _start(self) -> None:
        if HAVE_TZSET and self._destination_tzname is not None:
            self._orig_tz = os.environ.get("TZ")
            os.environ["TZ"] = self._destination_tzname
            tzset()

    def _stop(self) -> None:
        if HAVE_TZSET and self._destination_tzname is not None:
            if self._orig_tz is None:
                del os.environ["TZ"]
            else:
                os.environ["TZ"] = self._orig_tz
            tzset()


traveller_stack: list[Traveller] = []
original_uuid_generate_time_safe = None
original_uuid_uuid_create = None


class travel:
    def __init__(self, destination: DestinationType, *, tick: bool = True) -> None:
        self.destination_timestamp, self.destination_tzname = extract_timestamp_tzname(
            destination
        )
        self.tick = tick

    def start(self) -> Traveller:
        if not traveller_stack:
            _time_machine.patch()

            # During time travel, patch the uuid module's time-based generation function to
            # None, which makes it use time.time(). Otherwise it makes a system call to
            # find the current datetime. The time it finds is stored in generated UUID1
            # values.
            global original_uuid_generate_time_safe
            global original_uuid_uuid_create

            original_uuid_generate_time_safe = uuid._generate_time_safe  # type: ignore[attr-defined]
            original_uuid_uuid_create = uuid._UuidCreate  # type: ignore[attr-defined]
            uuid._generate_time_safe = None  # type: ignore[attr-defined]
            uuid._UuidCreate = None  # type: ignore[attr-defined]

        traveller = Traveller(
            destination_timestamp=self.destination_timestamp,
            destination_tzname=self.destination_tzname,
            tick=self.tick,
        )
        traveller_stack.append(traveller)
        traveller._start()

        return traveller

    def stop(self) -> None:
        traveller_stack.pop()._stop()

        if not traveller_stack:
            _time_machine.unpatch()

            global original_uuid_generate_time_safe
            global original_uuid_uuid_create

            uuid._generate_time_safe = original_uuid_generate_time_safe  # type: ignore[attr-defined]
            uuid._UuidCreate = original_uuid_uuid_create  # type: ignore[attr-defined]
            original_uuid_generate_time_safe = None
            original_uuid_uuid_create = None

    def __enter__(self) -> Traveller:
        return self.start()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.stop()

    async def __aenter__(self) -> Traveller:
        return self.start()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.stop()

    @overload
    def __call__(self, wrapped: TestCaseType) -> TestCaseType:  # pragma: no cover
        ...

    @overload
    def __call__(self, wrapped: _AF) -> _AF:  # pragma: no cover
        ...

    @overload
    def __call__(self, wrapped: _F) -> _F:  # pragma: no cover
        ...

    # 'Any' below is workaround for Mypy error:
    # Overloaded function implementation does not accept all possible arguments
    # of signature
    def __call__(
        self, wrapped: TestCaseType | _AF | _F | Any
    ) -> TestCaseType | _AF | _F | Any:
        if isinstance(wrapped, type):
            # Class decorator
            if not issubclass(wrapped, TestCase):
                raise TypeError("Can only decorate unittest.TestCase subclasses.")

            # Modify the setUpClass method
            orig_setUpClass = wrapped.setUpClass.__func__  # type: ignore[attr-defined]

            @functools.wraps(orig_setUpClass)
            def setUpClass(cls: type[TestCase]) -> None:
                self.__enter__()
                try:
                    orig_setUpClass(cls)
                except Exception:
                    self.__exit__(*sys.exc_info())
                    raise

            wrapped.setUpClass = classmethod(setUpClass)  # type: ignore[assignment]

            orig_tearDownClass = (
                wrapped.tearDownClass.__func__  # type: ignore[attr-defined]
            )

            @functools.wraps(orig_tearDownClass)
            def tearDownClass(cls: type[TestCase]) -> None:
                orig_tearDownClass(cls)
                self.__exit__(None, None, None)

            wrapped.tearDownClass = classmethod(  # type: ignore[assignment]
                tearDownClass
            )
            return cast(TestCaseType, wrapped)
        elif inspect.iscoroutinefunction(wrapped):

            @functools.wraps(wrapped)
            async def wrapper(*args: Any, **kwargs: Any) -> Any:
                with self:
                    return await wrapped(*args, **kwargs)

            return cast(_AF, wrapper)
        else:
            assert callable(wrapped)

            @functools.wraps(wrapped)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                with self:
                    return wrapped(*args, **kwargs)

            return cast(_F, wrapper)


# datetime module


def now(tz: dt.tzinfo | None = None) -> dt.datetime:
    return dt.datetime.fromtimestamp(time(), tz)


def utcnow() -> dt.datetime:
    return dt.datetime.fromtimestamp(time(), dt.timezone.utc).replace(tzinfo=None)


# time module


def clock_gettime(clk_id: int) -> float:
    if clk_id != CLOCK_REALTIME:
        result: float = _time_machine.original_clock_gettime(clk_id)
        return result
    return time()


def clock_gettime_ns(clk_id: int) -> int:
    if clk_id != CLOCK_REALTIME:
        result: int = _time_machine.original_clock_gettime_ns(clk_id)
        return result
    return time_ns()


def gmtime(secs: float | None = None) -> struct_time:
    result: struct_time
    if secs is not None:
        result = _time_machine.original_gmtime(secs)
    else:
        result = _time_machine.original_gmtime(traveller_stack[-1].time())
    return result


def localtime(secs: float | None = None) -> struct_time:
    result: struct_time
    if secs is not None:
        result = _time_machine.original_localtime(secs)
    else:
        result = _time_machine.original_localtime(traveller_stack[-1].time())
    return result


def strftime(format: str, t: _TimeTuple | struct_time | None = None) -> str:
    result: str
    if t is not None:
        result = _time_machine.original_strftime(format, t)
    else:
        result = _time_machine.original_strftime(format, localtime())
    return result


def time() -> float:
    return traveller_stack[-1].time()


def time_ns() -> int:
    return traveller_stack[-1].time_ns()


# pytest plugin

if HAVE_PYTEST:  # pragma: no branch

    def pytest_collection_modifyitems(items: list[pytest.Item]) -> None:
        """
        Add the fixture to any tests with the marker.
        """
        for item in items:
            if item.get_closest_marker("time_machine"):
                item.fixturenames.insert(0, "time_machine")  # type: ignore[attr-defined]

    def pytest_configure(config: pytest.Config) -> None:
        """
        Register the marker.
        """
        config.addinivalue_line(
            "markers", "time_machine(...): set the time with time-machine"
        )

    class TimeMachineFixture:
        traveller: travel | None
        traveller_obj: Traveller | None

        def __init__(self) -> None:
            self.traveller = None
            self.traveller_obj = None

        def move_to(
            self,
            destination: DestinationType,
            tick: bool | None = None,
        ) -> None:
            if self.traveller is None:
                if tick is None:
                    tick = True
                self.traveller = travel(destination, tick=tick)
                self.traveller_obj = self.traveller.start()
            else:
                assert self.traveller_obj is not None
                self.traveller_obj.move_to(destination, tick=tick)

        def shift(self, delta: dt.timedelta | int | float) -> None:
            if self.traveller is None:
                raise RuntimeError(
                    "Initialize time_machine with move_to() before using shift()."
                )
            assert self.traveller_obj is not None
            self.traveller_obj.shift(delta=delta)

        def stop(self) -> None:
            if self.traveller is not None:
                self.traveller.stop()

    @pytest.fixture(name="time_machine")
    def time_machine_fixture(
        request: pytest.FixtureRequest,
    ) -> TypingGenerator[TimeMachineFixture, None, None]:
        fixture = TimeMachineFixture()
        marker = request.node.get_closest_marker("time_machine")
        if marker:
            fixture.move_to(*marker.args, **marker.kwargs)

        yield fixture
        fixture.stop()


# escape hatch


class _EscapeHatchDatetimeDatetime:
    def now(self, tz: dt.tzinfo | None = None) -> dt.datetime:
        result: dt.datetime = _time_machine.original_now(tz)
        return result

    def utcnow(self) -> dt.datetime:
        result: dt.datetime = _time_machine.original_utcnow()
        return result


class _EscapeHatchDatetime:
    def __init__(self) -> None:
        self.datetime = _EscapeHatchDatetimeDatetime()


class _EscapeHatchTime:
    def clock_gettime(self, clk_id: int) -> float:
        result: float = _time_machine.original_clock_gettime(clk_id)
        return result

    def clock_gettime_ns(self, clk_id: int) -> int:
        result: int = _time_machine.original_clock_gettime_ns(clk_id)
        return result

    def gmtime(self, secs: float | None = None) -> struct_time:
        result: struct_time = _time_machine.original_gmtime(secs)
        return result

    def localtime(self, secs: float | None = None) -> struct_time:
        result: struct_time = _time_machine.original_localtime(secs)
        return result

    def strftime(self, format: str, t: _TimeTuple | struct_time | None = None) -> str:
        result: str
        if t is not None:
            result = _time_machine.original_strftime(format, t)
        else:
            result = _time_machine.original_strftime(format)
        return result

    def time(self) -> float:
        result: float = _time_machine.original_time()
        return result

    def time_ns(self) -> int:
        result: int = _time_machine.original_time_ns()
        return result


class _EscapeHatch:
    def __init__(self) -> None:
        self.datetime = _EscapeHatchDatetime()
        self.time = _EscapeHatchTime()

    def is_travelling(self) -> bool:
        return bool(traveller_stack)


escape_hatch = _EscapeHatch()
